NVFP4 primary weights#2690

Closed
WanZzzzzz wants to merge 47 commits into NVIDIA:main from WanZzzzzz:fp4_primary_weights

Conversation


@WanZzzzzz WanZzzzzz commented Feb 19, 2026

Description

This PR adds NVFP4 partial cast support for distributed training with ZeRO/FSDP optimizers. It enables efficient casting of FP32 master weight shards to NVFP4 model weights with coordinated scaling across data parallel ranks, while minimizing CPU overhead in large-scale training.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

This PR introduces NVFP4 partial cast infrastructure and optimizations for distributed training:

NVFP4 Partial Cast Kernel (nvfp4_2d_partial_cast)

  • Implements nibble-accurate partial updates for NVFP4 tensors in distributed settings
  • Supports two-level NVFP4 scaling: global FP32 scale + per-block FP8 E4M3 scale
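As a rough illustration of the two-level scheme, here is a pure-Python sketch. It assumes the usual NVFP4 constants (FP4 E2M1 max magnitude 6, FP8 E4M3 max magnitude 448); `two_level_scales` and its exact arithmetic are illustrative, not the kernel's actual code.

```python
E2M1_MAX = 6.0    # largest magnitude representable in FP4 E2M1
E4M3_MAX = 448.0  # largest magnitude representable in FP8 E4M3

def two_level_scales(global_amax: float, block_amax: float):
    """Sketch of two-level NVFP4 scaling for one 16-element block:
      - one global FP32 scale per tensor, derived from the global amax
      - one per-block decode scale, later stored as FP8 E4M3,
        derived from the block amax
    The formulas assume the common NVFP4 convention; the kernel's
    constants may differ.
    """
    global_scale = global_amax / (E4M3_MAX * E2M1_MAX)
    # per-block scale relative to the global scale
    block_scale = block_amax / (E2M1_MAX * global_scale) if global_scale else 0.0
    return global_scale, block_scale
```

A tensor whose global amax exactly fills the combined FP8×FP4 range gets a global scale of 1.0, and a block whose amax fills the FP4 range gets a block scale of 1.0.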

NVFP4 Transpose Kernel (nvfp4_transpose)

  • Custom transpose kernel for nibble-packed NVFP4 data with shared memory optimization
  • Uses vectorized uint2 loads/stores with 64×64 tiles for efficient memory access
  • Handles nibble repacking during transpose (unlike FP8 byte transpose)
  • Enables columnwise data generation for GEMM operations after rowwise AllGather
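To make the nibble-repacking requirement concrete, here is a reference transpose in pure Python. The helper names (`unpack_nibbles`, `nvfp4_transpose_ref`) and the low-nibble-first packing order are assumptions for illustration; the CUDA kernel's vectorized 64×64-tile implementation is far more involved.

```python
def unpack_nibbles(packed: bytes, rows: int, cols: int):
    """Unpack rowwise nibble-packed FP4 data into a rows x cols grid of
    nibbles. Assumes the low nibble holds the even-indexed element (the
    packing order is an assumption)."""
    assert cols % 2 == 0
    grid = []
    for r in range(rows):
        row = []
        for b in range(cols // 2):
            byte = packed[r * (cols // 2) + b]
            row.append(byte & 0xF)   # even column
            row.append(byte >> 4)    # odd column
        grid.append(row)
    return grid

def nvfp4_transpose_ref(packed: bytes, rows: int, cols: int) -> bytes:
    """Reference nibble-repacking transpose: unpack, transpose, repack.
    Unlike an FP8 byte transpose, two values share a byte, so every output
    byte is rebuilt from two different input rows. Requires even rows/cols."""
    assert rows % 2 == 0
    grid = unpack_nibbles(packed, rows, cols)
    out = bytearray()
    for c in range(cols):
        for r in range(0, rows, 2):
            out.append(grid[r][c] | (grid[r + 1][c] << 4))
    return bytes(out)
```

Applying the transpose twice returns the original packed buffer, which is the natural correctness check a test like test_nvfp4_transpose_kernel can build on.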

Fused Scale Kernel (nvfp4_fused_scale)

  • Fuses per-block scale computation, global amax copy, and FP8 scale expansion into a single kernel
  • Eliminates multiple kernel launches and avoids D2H transfers by accepting tensor pointers
  • Reduces kernel launch overhead in the critical path
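To illustrate what the fusion buys, the sketch below performs the three fused steps in one pass over the per-block amaxes: decode-scale computation (using the block_amax * 448 / global_amax formula from the PR's flowchart), quantization to FP8 E4M3, and the global-amax copy. `fused_scale` and `to_e4m3` are hypothetical names, and the E4M3 rounding is simplified (no subnormal handling).

```python
import math

def to_e4m3(x: float) -> float:
    """Simplified round-to-nearest FP8 E4M3 for nonnegative inputs:
    saturate at 448 and keep 4 significant bits (subnormals ignored)."""
    if x == 0.0:
        return 0.0
    x = min(x, 448.0)
    m, e = math.frexp(x)            # x = m * 2**e with m in [0.5, 1)
    return math.ldexp(round(m * 16) / 16, e)

def fused_scale(block_amax_rows, global_amax, global_amax_out):
    """One pass standing in for the fused kernel: compute each block's
    decode scale, quantize it to E4M3, and record the global amax,
    instead of three separate launches plus a D2H round trip."""
    global_amax_out[0] = global_amax   # fused global-amax copy
    return [[to_e4m3(b * 448.0 / global_amax) for b in row]
            for row in block_amax_rows]
```

Because everything reads tensor-resident values, nothing has to be synchronized back to the host between the steps, which is the point of passing tensor pointers instead of scalars.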

Multi-Tensor Dispatch Pattern

  • C++-side loop dispatch for NVFP4 multi-tensor operations
  • Reduces Python–C++ transition overhead compared to per-tensor Python loops
  • Collects metadata in Python and executes batched operations in C++ wrappers
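The dispatch pattern can be sketched in a few lines; here a Python function stands in for the C++ batched entry point, and `multi_tensor_apply`, `op`, and `metadata` are hypothetical names. The real win is that the loop body runs on the C++ side, so Python is crossed once per batch rather than once per tensor.

```python
def multi_tensor_apply(op, tensor_lists, metadata):
    """Stand-in for the C++ batched dispatcher: a single call that loops
    over all tensors, given metadata collected up front in Python."""
    results = []
    for args, meta in zip(zip(*tensor_lists), metadata):
        results.append(op(*args, **meta))
    return results

# Hypothetical usage: scale two "tensors" with per-tensor metadata
# in one dispatcher call.
scaled = multi_tensor_apply(
    lambda t, scale: [v * scale for v in t],
    [[[1, 2], [3, 4]]],
    [{"scale": 2}, {"scale": 10}],
)
```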

CPU Overhead Optimizations

  • Batched dtype conversion via torch.cat / torch.split
  • Replaced torch.zeros() with torch.empty() for immediately written buffers
  • Consolidated metadata collection and allocation phases
  • Optimized bucket partitioning for expert parallel buffers
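The batched dtype conversion follows a concatenate-convert-split pattern. The sketch below mirrors the torch.cat + .to(dtype) + torch.split sequence with plain Python lists; `batched_convert` is a hypothetical name and `convert` plays the role of the dtype cast, so that one conversion call replaces a per-tensor loop.

```python
def batched_convert(buffers, convert):
    """Concatenate shards, convert once, split back to the original
    sizes - the same shape as torch.cat(...).to(dtype) + torch.split."""
    sizes = [len(b) for b in buffers]
    flat = convert([x for b in buffers for x in b])  # single conversion call
    out, off = [], 0
    for n in sizes:
        out.append(flat[off:off + n])
        off += n
    return out
```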

Scale Computation Improvements

  • Fixed floating-point precision mismatch between Python and CUDA
  • Uses FP32 constants consistent with CUDA arithmetic
  • Ensures bitwise-identical results between partial and full quantization paths
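The precision issue comes from Python computing in float64 while CUDA kernels compute in float32. A minimal sketch of the fix, with illustrative values rather than the kernel's actual constants: route every intermediate result through an FP32 round trip so both sides see the same bits.

```python
import struct

def f32(x: float) -> float:
    """Round a Python float (float64) to FP32, as CUDA float arithmetic
    would produce it."""
    return struct.unpack("f", struct.pack("f", x))[0]

# Python's default float64 result differs from what a float32 kernel computes:
double_scale = 1.0 / 3.0                     # float64
cuda_like_scale = f32(f32(1.0) / f32(3.0))   # float32, CUDA-consistent
```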

New Public API

cast_master_weights_to_nvfp4()

  • Casts FP32 master weights to NVFP4 model weights
  • Handles global and per-block amax reduction across data parallel groups
  • Designed for low CPU overhead in distributed training loops

Testing

  • test_nvfp4_transpose_kernel: verifies correctness of the nibble-packed transpose
  • test_nvfp4_partial_cast_matches_full: multi-GPU; partial cast + all-gather equals full cast
  • test_single_gpu_partial_cast_vs_full: single-GPU; offset=0 partial cast matches the reference quantizer
  • _test_cast_master_weights_to_nvfp4: 500-iteration training loop with bitwise-identical loss

This feature also passed numerical validation in GPT-3 training on the corresponding Megatron-Core branch:

https://gitlab-master.nvidia.com/qiyuw/megatron-lm-all/-/tree/fp4_primary_opt?ref_type=heads

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@WanZzzzzz WanZzzzzz changed the title Fp4 primary weights NVFP4 primary weights Feb 19, 2026

greptile-apps bot commented Feb 19, 2026

Greptile Summary

This PR implements NVFP4 partial cast infrastructure for distributed training with ZeRO/FSDP optimizers, enabling efficient FP32→NVFP4 conversion of weight shards with coordinated scaling across data parallel ranks.

Key additions:

  • cast_master_weights_to_nvfp4() public API for distributed weight casting
  • NVFP4 2D partial cast kernel with nibble-accurate updates for weight shards
  • Custom transpose kernel handling nibble repacking (unlike FP8 byte transpose)
  • Fused scale kernel combining per-block scale computation, global amax copy, and FP8 expansion
  • Multi-tensor dispatch pattern reducing Python-C++ overhead via batched operations
  • CPU optimizations: batched dtype conversion, torch.empty() instead of torch.zeros(), consolidated metadata collection

Critical issue:

  • The undefined variable new_rowwise_data in the replace_raw_data() function (line 54) will cause a runtime error

Testing:

  • 500-iteration training loop validates bitwise-identical loss between NVFP4 and reference
  • Multi-GPU tests verify partial cast + all-gather equals full cast
  • NVFP4 transpose kernel correctness validated
  • Feature validated in GPT-3 training on Megatron-Core branch

Performance impact:

  • Eliminates multiple kernel launches through fused operations
  • Avoids D2H transfers by accepting tensor pointers
  • Reduces kernel launch overhead in critical training path

Confidence Score: 4/5

  • This PR is mostly safe to merge after fixing the critical undefined variable bug
  • The implementation is well-designed, with comprehensive testing (a 500-iteration training loop and multi-GPU validation), and has been validated in GPT-3 training. The CUDA kernels are well documented with proper error checking. However, there is a critical undefined-variable error that will cause an immediate runtime failure if replace_raw_data() is called with an NVFP4Tensor. Once that is fixed, the code quality is high.
  • Pay close attention to transformer_engine/pytorch/tensor/utils.py which contains the undefined variable bug that must be fixed before merge

Important Files Changed

  • transformer_engine/pytorch/tensor/utils.py: adds the cast_master_weights_to_nvfp4() API with batched multi-tensor operations; contains the critical undefined-variable bug in replace_raw_data()
  • transformer_engine/common/recipe/nvfp4.cu: implements the NVFP4 CUDA kernels (partial cast, transpose, fused scale); well documented, with clear kernel-design comments and proper error checking
  • transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp: C++ wrappers for NVFP4 partial cast operations with multi-tensor batching support; proper input validation and stream handling
  • tests/pytorch/distributed/test_cast_master_weights_to_fp8.py: comprehensive test suite with a 500-iteration training loop validating bitwise-identical loss between the NVFP4 and reference implementations

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Master Weights FP32<br/>Sharded across DP ranks] --> B[Batched dtype conversion<br/>torch.cat + to + split]
    B --> C[Multi-tensor partial amax<br/>nvfp4_multi_tensor_compute_partial_amax]
    C --> D[Per-block amax<br/>16x16 tiles]
    C --> E[Global amax<br/>per tensor]
    D --> F[AllReduce MAX<br/>block amax across DP]
    E --> G[AllReduce MAX<br/>global amax across DP]
    F --> H[Fused scale kernel<br/>nvfp4_fused_scale]
    G --> H
    H --> I[Compute per-block decode scale<br/>block_amax * 448 / global_amax]
    H --> J[Expand to row-level<br/>Convert to FP8 E4M3]
    H --> K[Copy global amax to target]
    I --> L[Multi-tensor partial cast<br/>nvfp4_multi_tensor_2d_partial_cast]
    J --> L
    L --> M[NVFP4 packed data<br/>2 nibbles per byte<br/>nibble-accurate updates]
    M --> N[AllGather<br/>Gather full model weights]
    N --> O[Multi-tensor columnwise creation<br/>nvfp4_multi_tensor_create_columnwise]
    O --> P[NVFP4 transpose<br/>Nibble repacking]
    O --> Q[Scale transpose<br/>Rowwise to columnwise]
    P --> R[Ready for GEMM<br/>Columnwise data + scales]
    Q --> R

Last reviewed commit: 687c8b6


@greptile-apps greptile-apps bot left a comment


10 files reviewed, 1 comment


elif isinstance(tensor, NVFP4Tensor):
    old_rowwise = tensor._rowwise_data
    assert old_rowwise.dtype == new_raw_data.dtype, "The data types of raw data don't match"
    new_rowwise_data.detach().copy_(old_rowwise)

new_rowwise_data is undefined.

Suggested change:
- new_rowwise_data.detach().copy_(old_rowwise)
+ new_raw_data.detach().copy_(old_rowwise)

@WanZzzzzz WanZzzzzz closed this Feb 19, 2026
@WanZzzzzz (Author)

It is hard to add sign-off to previous commits, so this was reopened as a new PR: #2691
